本文介绍了TensorFlow的程序结构和编程相关的知识。

Introduction to Computing Graph

  • TensorFlow中程序都是以计算图(computational graph)的形式来进行的,大致可以包含两部分:构造计算图运行计算图
  • 计算图可以看作是一系列运算操作的静态集合,其规定了计算的流程。注意计算图是静态的,只构造好计算图是无法计算出结果的,必须运行计算图才可以得到结果(使用tf.Session)。

构造计算图

  • 计算图由两种成分组成:tf.Tensortf.Operationtf.Operation可以看作是图中的节点,tf.Tensor可看作是图中的边。

  • 如何构造计算图:

    1. 调用相关方法,生成tf.Operation (node) 和 tf.Tensor (edge) 对象
    2. 将它们添加到tf.Graph对象中。
    • 注:实际上tensorflow会自动将它们添加到default graph中,许多程序也仅仅使用default graph,因此第二步就可以省略。但如果我们要使用多个计算图,就需要手动将operation和tensor添加到相应的graph中。

multiple graphs programming

  • 大部分情况下一个graph就足够了,但如果有时候一个graph太复杂,或者有太多无关的操作,可以另开一个graph来计算。
  • 具体步骤:
    1. 通过tf.Graph()函数来生成graph;
    2. 通过with another_graph.as_default():语句块来切换当前的默认graph,在这个语句块内的所有操作都会添加到默认的计算图中;
    3. 可通过tf.get_default_graph()来获取当前默认的计算图,返回tf.Graph()对象。

运行计算图

创造session

  • 计算图的运行是通过tf.Session来进行的,其可以理解为计算图的运行机制,负责执行运算图所定义的tensor或operation。
  • 要执行计算图首先需要创造一个session,因为session对象使用后需要释放(因为session会占有一定物理资源如GPU、网络连接等),所以建议用with语句块来创造session对象,这样语句块结束后可以自动释放资源。
    1
    2
    with tf.Session() as sess:
    # code relavant to sess

初始化session

  • session的构造函数有三个optional参数,可以让我们对session进行一些设置__init__(target = '', graph = None, config = None);
  • target:执行session的设备,常用于分布式计算情境下。默认是在当前本地设备上执行。
  • graph:session所执行的计算图,默认是default graph,如果只使用一个计算图不用管这个参数。
  • config:session的配置参数,通过传入tf.ConfigProto对象来对session进行设置。ConfigProto常用参数:
    • allow_soft_placement:是否允许设备代替,指定在GPU上运行的代码,如果GPU不存在,会被替换到CPU上执行。通常设为true
    • log_device_placement:是否打印设备的分配信息,通常设为flase.

执行session

  • session可以执行的有两种类型,一种是tensor,一种是operaton;若执行的是tensor,则会返回tensor所对应的值(type:numpy.ndarray);若执行的是operation,则方法只是执行相应的操作,不会返回值,如initializetrain操作。
  • session的执行也有两种方法:
    • 一种是sess.run(tensor/operation)
    • 另一种是调用tf.Operation.run()/tf.Tensor.eval()函数:op.run()实际是tf.get_default_session().run(op)的简写,tensor.eval()也是tf.get_default_session().run(tensor)的简写,因此这两个方法执行的前需要确保正确设置了default session(上一种方法因为已指定了session,也就不需要default session了)。
    • 而如何设置default session,主要通过with语句块:
      1. with tf.Session():,生成session的同时将其自动设置为了default session(当然出语句块也就清除掉了);
      2. with sess.as_default():,将已有的某个session设置为default session。as_default()方法会返回一个context manager,对应的session在本语句块内会被自动设置为default session。
      • 注:as_default()方法出语句块时只会清除default session设置,并不释放掉session(因为这个方法本就是为session的重用而生的),因此如果session不用了要手动close掉。
        1
        2
        3
        4
        sess = tf.Session()
        print(sess.run(变量名)) # sess.run会自动检测计算这个变量所需要的依赖,之后执行计算图,计算出结果并输出。

        print(sess.run({"变量1": 变量1, "变量2": 变量2})) # sess.run可以一次传入多个变量,但要以字典的形式传入(即键值对的形式)

构造Layer

  • 计算图的一个重要构成部分就是layer,也就是神经网络中的layer,在layer中封装了可以被训练的变量(如weights, biases等)以及作用于他们的操作(各种优化算法),是神经网络训练的核心。
  • 关于layer的使用,把它理解成一个函数即可,本质就是:输入数据->经过一定运算->输出数据的过程。
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    # 构造一个densely-connected layer
    sess = tf.Session()
    x = tf.placeholder(tf.float32, shape=[None, 3])
    linear_model = tf.layers.Dense(units=1) # units:dimensionality of output unit??
    y = linear_model(x)

    # 必须先对layer中的参数初始化(weight matrix 和 bias)
    init = tf.global_variables_initializer()
    sess.run(init)

    print(sess.run(y, feed_dict={x: [[1, 2, 3], [1, 2, 3]]}))

在不同的设备上执行operation

通过with tf.device():语句:eg. with tf.device("/device:CPU:0"):

TensorFlow中的with语句

在TensorFlow中可以用with语句块进行资源管理的类:
with tf.Graph():
with tf.Session():
with tf.device()

TensorFlow训练实例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# inputs and labels
x = tf.constant([[1], [2], [3], [4]], dtype=tf.float32)
y_true = tf.constant([[0], [-1], [-2], [-3]], dtype=tf.float32)

# define the model
linear_model = tf.layers.Dense(units=1)

# define prediction
y_pred = linear_model(x)

# define loss function
loss = tf.losses.mean_squared_error(labels=y_true, predictions=y_pred)

# define optimizer
optimizer = tf.train.GradientDescentOptimizer(0.01)

# define train operation
train = optimizer.minimize(loss)

# train the model (run train operation in session), iterate 100 times
for i in range(100):
_, loss_value = sess.run((train, loss))
print(loss_value)

Post Date: 2019-03-02

版权声明: 本文为原创文章,转载请注明出处